//
// Package for reading, writing, and applying Simple File Verification (SFV) files.
//
// Reference:
//   http://en.wikipedia.org/wiki/Simple_file_verification
//

package main

import (
	"bufio"
	"flag"
	"fmt"
	"hash/crc32"
	"io"
	"keeper/elog"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"
)

// flagCrcBufferSize is the -crc-buffer-size command line flag: the size,
// in bytes, of the read buffer handed to NewCrc32Buffer.  Defaults to 1 MiB.
var flagCrcBufferSize = flag.Int64(
	"crc-buffer-size",
	1<<20,
	"How big, in bytes, should the buffer be for reading during crc32.",
)

// -- public methods

// SfvRecord records checksums for all files under 'path', which is
// traversed recursively.  The checksums are written out in SFV format to
// 'sfv_filename', with each filename expressed relative to the directory
// that holds the SFV file whenever that is possible.
//
// Returns true once the SFV file has been completely written (flushed
// and closed).  Returns false if the scan reported errors, in which case
// no SFV file is created.  Failures while creating or writing the SFV
// file are reported through elog.Fatal.
func SfvRecord(path, sfv_filename string) bool {
	sfvDir := filepath.Dir(sfv_filename)

	// Sanity check
	if _, err := filepath.Rel(sfvDir, path); err != nil {
		elog.Println("SFV file not parent directory; will use absolute paths.")
	}

	// Scan directory, capturing info.
	var sfv sfvTable
	if errors := sfv.scan(path); errors > 0 {
		elog.Println("Unable to proceed due to ERRORs.")
		return false
	}

	// Record table to SFV file.
	// TODO: use "safe overwrite": i) write to tmp; ii) move orig to
	// tmp2; iii) mv tmp to actual; iv) rm tmp2 (use foo.old and foo.new)
	file, err := os.Create(sfv_filename)
	if err != nil {
		elog.Fatal(err)
	}
	// Safety net for early exits; the success path closes explicitly
	// below so that the Close error can be checked.
	defer file.Close()
	wrt := bufio.NewWriter(file)

	if _, err := fmt.Fprintln(wrt, "; Generated by keeper.go"); err != nil {
		elog.Fatal(err)
	}
	if _, err := fmt.Fprintln(wrt, ";"); err != nil {
		elog.Fatal(err)
	}
	for _, entry := range sfv.entries {
		rel_filename, err := filepath.Rel(sfvDir, entry.filename)
		if err != nil {
			// Cannot make path relative to SFV; use absolute path.  The
			// scanned path may itself be relative to the working
			// directory, so resolve it explicitly.
			rel_filename = entry.filename
			if abs, absErr := filepath.Abs(entry.filename); absErr == nil {
				rel_filename = abs
			}
		}
		if _, err := fmt.Fprintf(wrt, "%s %08X\n", rel_filename, entry.checksum); err != nil {
			elog.Fatal(err)
		}
	}

	// Buffered data only reaches the file on Flush, and some filesystems
	// report write failures only on Close; check both before claiming
	// success.
	if err := wrt.Flush(); err != nil {
		elog.Fatal(err)
	}
	if err := file.Close(); err != nil {
		elog.Fatal(err)
	}
	elog.Printf("Wrote '%s'.\n", sfv_filename)
	return true
}

// SfvVerify verifies the integrity of files on disk.  The files to check
// and their expected CRC-32 values are taken from the SFV file named by
// 'sfv_filename'.
//
// One status line per file is printed to standard output.  The returned
// slice holds one message for every file that could not be accessed or
// whose checksum did not match; it is empty when everything verified.
func SfvVerify(sfv_filename string) (errors []string) {
	began := time.Now()

	var table sfvTable
	table.read(sfv_filename)
	if len(table.entries) == 0 {
		elog.Println("No entries captured from SFV file; done.")
		return
	}

	// Keep the filename short enough that the full status line stays
	// below 80 columns.
	const nameWidth = 68

	var bytesChecked int64
	crc := NewCrc32Buffer(*flagCrcBufferSize)
	for i := range table.entries {
		want := table.entries[i]
		fmt.Print(CollapseMiddle(want.filename, nameWidth), ": ")

		got, ok := crc.ComputeFileChecksum(want.filename)
		if !ok {
			errors = append(errors,
				fmt.Sprintf("Error accessing %s", want.filename))
			continue
		}

		info, statErr := os.Stat(want.filename)
		if statErr != nil {
			elog.Fatal(statErr)
		}
		bytesChecked += info.Size()

		if got == want.checksum {
			fmt.Println("CRC OK")
			continue
		}

		// Mismatch: print the error on a separate line.
		msg := fmt.Sprintf("CRC mismatch for %s (expected: 0x%08X; was: 0x%08X)",
			want.filename, want.checksum, got)
		fmt.Println("*****")
		fmt.Println("   ", msg)
		errors = append(errors, msg)
	}

	elapsed := time.Since(began)
	elog.Println("Time elapsed:", elapsed)
	elog.Printf("Average speed: %5.1f MB/s\n",
		float64(bytesChecked)/1024/1024/elapsed.Seconds())
	return
}

// -- private methods

// sfvEntry is a single record of an SFV table: one file and its checksum.
// Field order matters; the scanner builds entries with a positional
// composite literal.
type sfvEntry struct {
	// Path of the file.  Entries loaded by read are joined onto the SFV
	// file's directory; entries produced by a scan hold the path as it
	// was visited during the directory walk.
	filename string
	// CRC-32 (IEEE polynomial) of the file contents.
	checksum uint32
}

// sfvTable is the in-memory form of an SFV file.  It is filled either by
// read (from an existing SFV file) or by scan (from a directory tree).
type sfvTable struct {
	// Entries in the order they were read or visited.
	entries []sfvEntry
}

// read parses the SFV file 'sfv_filename' and appends one entry per
// record to s.entries.
//
// Each record has the form "<filename> <hex crc32>", split at the last
// space on the line.  Blank lines and lines starting with ';' (comments)
// are skipped.  A final line without a trailing newline is parsed like
// any other.  Relative filenames are resolved against the directory of
// the SFV file; absolute filenames are kept as they are.
//
// I/O failures and malformed records are reported through elog.Fatal.
func (s *sfvTable) read(sfv_filename string) {
	file, err := os.Open(sfv_filename)
	if err != nil {
		elog.Fatal(err)
	}
	defer file.Close()

	const maxLineLength = 2048
	rdr := bufio.NewReaderSize(file, maxLineLength)
	sfvDir := filepath.Dir(sfv_filename)
	for {
		line, err := rdr.ReadString('\n')
		if err != nil && err != io.EOF {
			// Have real error.
			elog.Fatal(err)
		}
		// ReadString returns the trailing partial line together with
		// io.EOF, so the line must be processed before acting on EOF.
		atEOF := err == io.EOF

		line = strings.TrimSpace(line)
		if line != "" && line[0] != ';' {
			sepIndex := strings.LastIndex(line, " ")
			if sepIndex < 0 {
				elog.Fatal("Invalid SFV line detected:", line)
			}
			filename := strings.TrimSpace(line[:sepIndex])
			checksum := strings.TrimSpace(line[sepIndex+1:])
			checksum64, err := strconv.ParseUint(checksum, 16, 32)
			if err != nil {
				elog.Fatal("Trouble parsing CRC checksum string:", checksum)
			}
			// The relative paths in an SFV file are relative to the
			// location of the SFV file itself.  Absolute paths already
			// name the file and must not be re-rooted.
			if !filepath.IsAbs(filename) {
				filename = filepath.Join(sfvDir, filename)
			}
			s.entries = append(s.entries,
				sfvEntry{filename: filename, checksum: uint32(checksum64)})
		}

		if atEOF {
			// Done reading SFV file.
			return
		}
	}
}

// write logs one line per table entry.  It does not write any file yet;
// 'sfv_filename' is currently unused.
func (s *sfvTable) write(sfv_filename string) {
	for i := range s.entries {
		elog.Println("Writing entry for " + s.entries[i].filename)
	}
}

// scan walks the tree rooted at 'path', computing a checksum for every
// regular file and appending the results to s.entries.  A summary of
// the scan is logged when the walk finishes.
//
// Returns the number of errors encountered.  This counts files that
// could not be opened or read as well as directories and entries that
// the walk itself could not access, so a non-zero result means the
// table may be incomplete.
func (s *sfvTable) scan(path string) (num_errors int) {
	now := time.Now()
	scanner := SimpleScanner{
		sfv:              s,
		crc:              NewCrc32Buffer(*flagCrcBufferSize),
		start_time:       now,
		last_update_time: now,
	}

	// Approach based on:
	//  http://www.1771.in/closures-how-am-i-meant-to-use-filepath-walk-in-go.html
	visit := func(path string, f os.FileInfo, err error) error {
		if err != nil {
			// The walk could not stat or list this path.  Record it and
			// return nil so the rest of the tree is still scanned and
			// every problem is reported in one run.
			scanner.errors = append(scanner.errors,
				fmt.Sprintf("Walk failed for %s: %v", path, err))
			return nil
		}
		scanner.Visit(path, f)
		return nil
	}

	elog.Printf("Scanning %s...\n", path)
	if err := filepath.Walk(path, visit); err != nil {
		scanner.errors = append(scanner.errors,
			fmt.Sprintf("Walk failed for %s: %v", path, err))
	}

	duration := time.Since(scanner.start_time)

	elog.Println("")
	elog.Println("Finished scanning.")
	elog.Println("         dirs:", scanner.dir_count)
	elog.Println("        files:", scanner.file_count)
	elog.Println("        other:", scanner.other_count)
	elog.Println("  total bytes:", HumanReadable(scanner.total_bytes, "B"))
	elog.Println("   total time:", duration)
	elog.Printf(" average rate: %.1f MB/s\n",
		float64(scanner.total_bytes)/1024/1024/duration.Seconds())

	if len(scanner.errors) > 0 {
		num_errors = len(scanner.errors)
		elog.Println("ERRORs:")
		for _, msg := range scanner.errors {
			elog.Println("  ", msg)
		}
	}
	elog.Println("")
	return num_errors
}

// SimpleScanner accumulates the state of one directory scan: the table
// being filled, running counters, collected error messages, and the
// bookkeeping needed for periodic progress output.  Visit is called once
// per filesystem entry.
type SimpleScanner struct {
	// Table that receives an entry for every successfully checksummed file.
	sfv *sfvTable
	// Reusable read buffer used to checksum files.
	crc *crc32Buffer
	// When the scan began; used for the overall duration and rate.
	start_time time.Time
	// Time of the previous progress update.
	last_update_time time.Time
	// Value of total_bytes at the previous progress update.
	last_update_bytes uint64
	// Earliest time the next progress update may be printed.  The zero
	// value makes the first Visit print an update immediately.
	next_update time.Time
	// Number of entries visited with a non-nil FileInfo.
	total_count uint
	// Number of regular files successfully checksummed.
	file_count uint
	// Number of directories visited.
	dir_count uint
	// Number of entries that are neither directories nor regular files.
	other_count uint
	// Sum of the sizes of all successfully checksummed files.
	total_bytes uint64
	// One message per file that could not be opened or read.
	errors []string
}

// Visit accounts for one filesystem entry found during a scan.
//
// Directories are only counted.  Regular files are checksummed and, on
// success, appended to the SFV table; on failure a message is added to
// s.errors.  Anything else is counted as "other" and its mode is logged.
// A nil 'f' is logged and otherwise ignored.
//
// At most once per second a progress line (current rate and path) is
// printed to standard output, overwriting the previous one.
func (s *SimpleScanner) Visit(path string, f os.FileInfo) {
	if f == nil {
		elog.Println("nil FileInfo")
		return
	}

	s.total_count++
	mode := f.Mode()
	if mode.IsDir() {
		s.dir_count++
	} else if mode&os.ModeType == 0 { // regular file
		if checksum, ok := s.crc.ComputeFileChecksum(path); !ok {
			s.errors = append(s.errors,
				fmt.Sprint("Open/read failed for: ", path))
		} else {
			s.sfv.entries = append(s.sfv.entries, sfvEntry{path, checksum})
			s.file_count++
			s.total_bytes += uint64(f.Size())
		}
	} else {
		s.other_count++
		elog.Printf("encountered filemode %d\n", mode)
	}

	// Progress output is rate limited; bail out until the next slot.
	if !time.Now().After(s.next_update) {
		return
	}
	now := time.Now()

	const (
		lineWidth = 79
		nameWidth = 69
	)

	// Return to column 0, blank out the previous progress line, and
	// return to column 0 again.
	fmt.Print("\r")
	fmt.Print(strings.Repeat(" ", lineWidth) + "\r")

	shown := path
	if len(shown) > nameWidth {
		shown = shown[:nameWidth-3] + "..."
	}
	elapsed := now.Sub(s.last_update_time).Seconds()
	moved := s.total_bytes - s.last_update_bytes
	// Avoid division by zero.
	if elapsed > 0 {
		fmt.Printf("%5.1fMB/s %s", float64(moved)/1024/1024/elapsed, shown)
	}

	s.next_update = now.Add(time.Second)
	s.last_update_time = now
	s.last_update_bytes = s.total_bytes
}

// crc32Buffer holds a reusable read buffer for computing file checksums,
// so that checksumming many files does not allocate a buffer per file.
type crc32Buffer struct {
	buffer []byte
}

// NewCrc32Buffer returns a crc32Buffer whose read buffer is 'size' bytes
// long.  A size of zero or less is not usable as a read buffer (a
// negative length cannot be allocated and an empty buffer can never
// receive data), so in that case a 1 MiB buffer is used instead.
func NewCrc32Buffer(size int64) *crc32Buffer {
	const fallbackSize = 1024 * 1024
	if size <= 0 {
		size = fallbackSize
	}
	return &crc32Buffer{buffer: make([]byte, size)}
}

// ComputeFileChecksum returns the CRC-32 (IEEE polynomial) of the
// contents of the file at 'path', reading it through the reusable buffer.
//
// 'ok' is false, and the checksum is 0, when the file cannot be opened
// or a read fails; the underlying error is logged.  Callers decide how to
// report the failure, so one unreadable file does not abort a whole run.
func (buffer *crc32Buffer) ComputeFileChecksum(
	path string) (checksum uint32, ok bool) {

	file, err := os.Open(path)
	if err != nil {
		elog.Println(err)
		return 0, false
	}
	defer file.Close()

	data := buffer.buffer
	if len(data) == 0 {
		// Reading into an empty slice never makes progress and would
		// yield a checksum of 0 for every file; use a private scratch
		// buffer instead.
		data = make([]byte, 32*1024)
	}
	for {
		count, err := file.Read(data)
		if count > 0 {
			// Hash only the bytes actually read, without shrinking the
			// buffer used for the following reads.
			checksum = crc32.Update(checksum, crc32.IEEETable, data[:count])
		}
		if err == io.EOF {
			return checksum, true
		}
		if err != nil {
			elog.Println(err)
			return 0, false
		}
	}
}
